import os
from typing import Optional, Callable, Tuple, Any

import torch
from torchvision import datasets


class Faces(datasets.ImageFolder):
    """ImageFolder of face images, ordered by a numeric file-name prefix and cut into fixed-size splits.

    Every image file name must start with an integer followed by ``_``
    (e.g. ``1953_xyz.jpg``); samples are sorted by that integer and then
    split ``split`` (the half-open range ``[split * split_size, (split + 1) * split_size)``)
    is kept.

    In ``test_time`` mode each underlying example is visited
    ``steps_per_example`` consecutive times, and each item is a stack of
    ``batch_size`` transformed views of the same image.

    Args:
        root: Root directory passed to ``ImageFolder``.
        batch_size: Number of views stacked per item in ``test_time`` mode.
        split: Which ``split_size``-sized chunk of the sorted samples to keep.
        test_time: Enables the repeated / stacked-views indexing scheme.
        steps_per_example: How many consecutive indices map to one example in
            ``test_time`` mode.
        minimizer: Optional indexable with ``len()`` that remaps dataset
            positions to indices into ``self.samples``.
        transform: Transform applied to each loaded image. ``test_time`` mode
            requires it to return tensors (they are combined with ``torch.stack``).
        single_crop: If True, the transform is applied once and the result is
            repeated ``batch_size`` times instead of re-applying the transform.
        start_index: Offset added to the example index in ``test_time`` mode.
        split_size: Number of samples per split (default 5062, the value this
            class previously hard-coded).
    """

    def __init__(self, root: str, batch_size: int = 1, split: int = 0, test_time: bool = False, steps_per_example: int = 1,
                 minimizer=None, transform: Optional[Callable] = None, single_crop: bool = False, start_index: int = 0,
                 split_size: int = 5062):
        super().__init__(root=root, transform=transform)
        self.batch_size = batch_size
        self.minimizer = minimizer
        self.steps_per_example = steps_per_example
        self.single_crop = single_crop
        self.test_time = test_time
        self.start_index = start_index

        ordered = sorted(self.samples, key=self._file_index)
        first = split * split_size
        self.samples = ordered[first:first + split_size]
        # Keep the parent's mirrored attributes consistent with the re-ordered,
        # sliced sample list (ImageFolder exposes both `imgs` and `targets`).
        self.imgs = self.samples
        self.targets = [target for _, target in self.samples]
        print(f'Sample size is {len(self.samples)}')

    @staticmethod
    def _file_index(sample: Tuple[str, int]) -> int:
        """Return the integer prefix before the first '_' in the sample's file name."""
        # os.path.basename keeps this working with OS-native path separators.
        return int(os.path.basename(sample[0]).split('_')[0])

    def __len__(self) -> int:
        """Number of items: samples normally, repeated/stacked count in test_time mode."""
        if not self.test_time:
            return len(self.samples)
        base = len(self.samples) if self.minimizer is None else len(self.minimizer)
        # NOTE(review): the length is scaled by batch_size, but __getitem__ only
        # divides the index by steps_per_example (and ignores start_index here),
        # so with batch_size > 1 the trailing indices map past the data. Kept
        # as-is to preserve existing behaviour -- confirm the caller's intent.
        return self.steps_per_example * self.batch_size * base

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
            In test_time mode ``sample`` is a tensor of ``batch_size`` stacked views.

        Raises:
            ValueError: if ``test_time`` is set but no transform was given
                (raw loaded images cannot be stacked with ``torch.stack``).
        """
        if self.test_time:
            # steps_per_example consecutive indices share the same example.
            real_index = index // self.steps_per_example + self.start_index
        else:
            real_index = index
        if self.minimizer is not None:
            real_index = self.minimizer[real_index]
        path, target = self.samples[real_index]
        image = self.loader(path)

        if self.transform is None:
            if self.test_time:
                raise ValueError('test_time=True requires a transform that returns tensors')
            # Mirror ImageFolder: without a transform, return the loaded image.
            samples = image
        elif self.single_crop:
            # One transform application, repeated batch_size times.
            crop = self.transform(image)
            samples = torch.stack([crop] * self.batch_size, dim=0) if self.test_time else crop
        elif self.test_time:
            # Independent transform application per view.
            samples = torch.stack([self.transform(image) for _ in range(self.batch_size)], dim=0)
        else:
            samples = self.transform(image)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return samples, target